{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a one-time download: uncomment the next two cells to fetch the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#path = Config().data_path()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#! wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/summer2winter_yosemite.zip -P {path}\n",
    "#! unzip -q -n {path}/summer2winter_yosemite.zip -d {path}\n",
    "#! rm {path}/summer2winter_yosemite.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset folder inside the fastai data directory (presumably where the archive from the cell above unpacks).\n",
    "path = Config().data_path()/'summer2winter_yosemite'\n",
    "path.ls()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See [this tutorial](https://docs.fast.ai/tutorial.itemlist.html) for a detailed walkthrough of how/why this custom `ItemList` was created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageTuple(ItemBase):\n",
    "    \"Two images handled as a single item; `data` holds both tensors mapped through `2*x-1`.\"\n",
    "    def __init__(self, img1, img2):\n",
    "        self.img1,self.img2 = img1,img2\n",
    "        self._refresh_data()\n",
    "    \n",
    "    def _refresh_data(self):\n",
    "        \"Rebuild `obj` and `data` from the current `img1` and `img2`.\"\n",
    "        self.obj = (self.img1,self.img2)\n",
    "        self.data = [-1+2*self.img1.data,-1+2*self.img2.data]\n",
    "    \n",
    "    def apply_tfms(self, tfms, **kwargs):\n",
    "        \"Apply `tfms` to both images, then refresh `obj`/`data` so they reflect the transformed images.\"\n",
    "        self.img1 = self.img1.apply_tfms(tfms, **kwargs)\n",
    "        self.img2 = self.img2.apply_tfms(tfms, **kwargs)\n",
    "        # Without this refresh `data` would still hold the tensors computed in `__init__`,\n",
    "        # so the transforms (and any `size` passed in `kwargs`) would never reach the batch.\n",
    "        self._refresh_data()\n",
    "        return self\n",
    "    \n",
    "    def to_one(self):\n",
    "        \"Undo the `2*x-1` mapping and concatenate the two tensors along dim 2 into one `Image`.\"\n",
    "        return Image(0.5+torch.cat(self.data,2)/2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TargetTupleList(ItemList):\n",
    "    \"`ItemList` whose `reconstruct` rebuilds an `ImageTuple` from a tensor.\"\n",
    "    def reconstruct(self, t:Tensor):\n",
    "        # A zero-dimensional tensor is handed back untouched.\n",
    "        if t.dim() == 0: return t\n",
    "        first,second = t[0],t[1]\n",
    "        return ImageTuple(Image(first/2+0.5),Image(second/2+0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageTupleList(ImageList):\n",
    "    \"`ImageList` whose items are paired with an image opened from a random entry of `itemsB`.\"\n",
    "    _label_cls=TargetTupleList\n",
    "    def __init__(self, items, itemsB=None, **kwargs):\n",
    "        self.itemsB = itemsB\n",
    "        super().__init__(items, **kwargs)\n",
    "    \n",
    "    def new(self, items, **kwargs):\n",
    "        \"Forward `itemsB` to `super().new` together with `items`.\"\n",
    "        return super().new(items, itemsB=self.itemsB, **kwargs)\n",
    "    \n",
    "    def get(self, i):\n",
    "        \"Return item `i` wrapped in an `ImageTuple` with an image opened from a random entry of `itemsB`.\"\n",
    "        img1 = super().get(i)\n",
    "        fn = self.itemsB[random.randint(0, len(self.itemsB)-1)]\n",
    "        return ImageTuple(img1, open_image(fn))\n",
    "    \n",
    "    def reconstruct(self, t:Tensor): \n",
    "        \"Rebuild an `ImageTuple` from `t[0]` and `t[1]`, mapping each through `x/2+0.5`.\"\n",
    "        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))\n",
    "    \n",
    "    @classmethod\n",
    "    def from_folders(cls, path, folderA, folderB, **kwargs):\n",
    "        \"Build the list from `path/folderA`, using the items found in `path/folderB` as `itemsB`.\"\n",
    "        itemsB = ImageList.from_folder(path/folderB).items\n",
    "        res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)\n",
    "        res.path = path\n",
    "        return res\n",
    "    \n",
    "    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(12,6), **kwargs):\n",
    "        \"Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method.\"\n",
    "        rows = int(math.sqrt(len(xs)))\n",
    "        fig, axs = plt.subplots(rows,rows,figsize=figsize)\n",
    "        # With a single row/column `plt.subplots` returns a bare Axes instead of an array.\n",
    "        for i, ax in enumerate(axs.flatten() if rows > 1 else [axs]):\n",
    "            xs[i].to_one().show(ax=ax, **kwargs)\n",
    "        plt.tight_layout()\n",
    "\n",
    "    def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):\n",
    "        \"\"\"Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.\n",
    "        `kwargs` are passed to the show method.\"\"\"\n",
    "        figsize = ifnone(figsize, (12,3*len(xs)))\n",
    "        # `squeeze=False` keeps `axs` two-dimensional even for one row, so `axs[i,0]` also works when `len(xs) == 1`.\n",
    "        fig,axs = plt.subplots(len(xs), 2, figsize=figsize, squeeze=False)\n",
    "        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)\n",
    "        for i,(x,z) in enumerate(zip(xs,zs)):\n",
    "            x.to_one().show(ax=axs[i,0], **kwargs)\n",
    "            z.to_one().show(ax=axs[i,1], **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Items come from `trainA` and are paired with files from `trainB`; no validation split, empty labels,\n",
    "# `size=128` passed to the transforms and batches of 4.\n",
    "data = (ImageTupleList.from_folders(path, 'trainA', 'trainB')\n",
    "                      .split_none()\n",
    "                      .label_empty()\n",
    "                      .transform(get_transforms(), size=128)\n",
    "                      .databunch(bs=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.show_batch(rows=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the models that were introduced in the [cycleGAN paper](https://arxiv.org/abs/1703.10593)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):\n",
    "    \"Return the list `[ConvTranspose2d, norm_layer(ch_out), ReLU]`.\"\n",
    "    upsample = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride, padding=1, output_padding=1, bias=bias)\n",
    "    norm = norm_layer(ch_out)\n",
    "    activation = nn.ReLU(True)\n",
    "    return [upsample, norm, activation]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True, \n",
    "                       pad=1, stride:int=1, activ:bool=True, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:\n",
    "    \"Return `[padding layer (optional), Conv2d, norm_layer(ch_out), ReLU (optional)]` as a list.\"\n",
    "    # 'reflection' and 'border' pad with a dedicated layer; only 'zeros' pads inside the convolution.\n",
    "    explicit_pad = {'reflection': nn.ReflectionPad2d, 'border': nn.ReplicationPad2d}\n",
    "    layers = [explicit_pad[pad_mode](pad)] if pad_mode in explicit_pad else []\n",
    "    conv_padding = pad if pad_mode == 'zeros' else 0\n",
    "    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=conv_padding, stride=stride, bias=bias)\n",
    "    if init:\n",
    "        init(conv.weight)\n",
    "        conv_bias = getattr(conv, 'bias', None)\n",
    "        # `conv.bias` is None when the layer was built with `bias=False`.\n",
    "        if hasattr(conv_bias, 'data'): conv_bias.data.fill_(0.)\n",
    "    layers.extend([conv, norm_layer(ch_out)])\n",
    "    if activ: layers.append(nn.ReLU(inplace=True))\n",
    "    return layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResnetBlock(nn.Module):\n",
    "    \"Residual block returning `x + conv_block(x)`, where `conv_block` chains two `pad_conv_norm_relu` stages.\"\n",
    "    def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):\n",
    "        super().__init__()\n",
    "        assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'\n",
    "        norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n",
    "        first_stage = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)\n",
    "        # Dropout, when requested, sits between the two stages.\n",
    "        maybe_drop = [nn.Dropout(dropout)] if dropout != 0 else []\n",
    "        # The second stage has no activation.\n",
    "        second_stage = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)\n",
    "        self.conv_block = nn.Sequential(*first_stage, *maybe_drop, *second_stage)\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = self.conv_block(x)\n",
    "        return x + residual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:nn.Module=None, \n",
    "                     dropout:float=0., n_blocks:int=6, pad_mode:str='reflection')->nn.Module:\n",
    "    \"Build the generator: 7x7 stem, two stride-2 convs, `n_blocks` `ResnetBlock`s, two transposed convs, 7x7 conv + `Tanh`.\"\n",
    "    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n",
    "    bias = (norm_layer == nn.InstanceNorm2d)\n",
    "    layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)\n",
    "    # Two stride-2 convolutions, doubling the number of features each time.\n",
    "    for _ in range(2):\n",
    "        doubled = n_ftrs * 2\n",
    "        layers += pad_conv_norm_relu(n_ftrs, doubled, 'zeros', norm_layer, stride=2, bias=bias)\n",
    "        n_ftrs = doubled\n",
    "    for _ in range(n_blocks):\n",
    "        layers.append(ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias))\n",
    "    # Two transposed convolutions, halving the number of features each time.\n",
    "    for _ in range(2):\n",
    "        halved = n_ftrs // 2\n",
    "        layers += convT_norm_relu(n_ftrs, halved, norm_layer, bias=bias)\n",
    "        n_ftrs = halved\n",
    "    head = [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()]\n",
    "    return nn.Sequential(*layers, *head)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resnet_generator(3, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1, \n",
    "                 activ:bool=True, slope:float=0.2, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:\n",
    "    \"Return `[Conv2d, norm_layer(ch_out) (optional), LeakyReLU(slope) (optional)]` as a list.\"\n",
    "    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)\n",
    "    if init:\n",
    "        init(conv.weight)\n",
    "        conv_bias = getattr(conv, 'bias', None)\n",
    "        # `conv.bias` is None when the layer was built with `bias=False`.\n",
    "        if hasattr(conv_bias, 'data'): conv_bias.data.fill_(0.)\n",
    "    norm = [] if norm_layer is None else [norm_layer(ch_out)]\n",
    "    activation = [nn.LeakyReLU(slope, inplace=True)] if activ else []\n",
    "    return [conv, *norm, *activation]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def discriminator(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:\n",
    "    \"Stack `conv_norm_lr` stages, a final 1-channel `Conv2d` and, if `sigmoid`, a `Sigmoid`.\"\n",
    "    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n",
    "    bias = (norm_layer == nn.InstanceNorm2d)\n",
    "    # The first stage is built without a norm layer.\n",
    "    layers = conv_norm_lr(ch_in, n_ftrs, ks=4, stride=2, pad=1)\n",
    "    width = n_ftrs\n",
    "    for depth in range(n_layers-1):\n",
    "        wider = width*2 if depth <= 3 else width\n",
    "        layers.extend(conv_norm_lr(width, wider, norm_layer, ks=4, stride=2, pad=1, bias=bias))\n",
    "        width = wider\n",
    "    # Last stride-1 stage: the width doubles once more only when `n_layers <= 3`.\n",
    "    last_width = width*2 if n_layers <= 3 else width\n",
    "    layers.extend(conv_norm_lr(width, last_width, norm_layer, ks=4, stride=1, pad=1, bias=bias))\n",
    "    layers.append(nn.Conv2d(last_width, 1, kernel_size=4, stride=1, padding=1))\n",
    "    if sigmoid: layers.append(nn.Sigmoid())\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminator(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We group two discriminators and two generators in a single model, then a `Callback` will take care of training them properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGAN(nn.Module):\n",
    "    \"Single module holding two discriminators (`D_A`, `D_B`) and two generators (`G_A`, `G_B`).\"\n",
    "    \n",
    "    def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True, \n",
    "                 drop:float=0., norm_layer:nn.Module=None):\n",
    "        super().__init__()\n",
    "        # The discriminators end with a sigmoid only when `lsgan` is False.\n",
    "        disc_args = (ch_in, n_features, disc_layers, norm_layer)\n",
    "        gen_args = (ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)\n",
    "        #D_A: trained to make the difference between real input A and fake input A\n",
    "        self.D_A = discriminator(*disc_args, sigmoid=not lsgan)\n",
    "        #D_B: trained to make the difference between real input B and fake input B\n",
    "        self.D_B = discriminator(*disc_args, sigmoid=not lsgan)\n",
    "        #G_A: takes real input B and generates fake input A\n",
    "        self.G_A = resnet_generator(*gen_args)\n",
    "        #G_B: takes real input A and generates fake input B\n",
    "        self.G_B = resnet_generator(*gen_args)\n",
    "    \n",
    "    def forward(self, real_A, real_B):\n",
    "        \"Training: return `[fake_A, fake_B, idt_A, idt_B]`; otherwise the two fakes stacked along a new dim 1.\"\n",
    "        fake_A = self.G_A(real_B)\n",
    "        fake_B = self.G_B(real_A)\n",
    "        if self.training:\n",
    "            #Needed for the identity loss during training.\n",
    "            idt_A = self.G_A(real_A)\n",
    "            idt_B = self.G_B(real_B)\n",
    "            return [fake_A, fake_B, idt_A, idt_B]\n",
    "        return torch.cat([fake_A.unsqueeze(1),fake_B.unsqueeze(1)], 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`AdaptiveLoss` is a wrapper around a PyTorch loss function to compare an output of any size with a single number (0. or 1.). It will generate a target with the same shape as the output. A discriminator returns a feature map, and we want it to predict zeros (or ones) for each feature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdaptiveLoss(nn.Module):\n",
    "    \"Wrap `crit` so an output of any shape is compared with an all-ones or all-zeros target of that shape.\"\n",
    "    def __init__(self, crit):\n",
    "        super().__init__()\n",
    "        self.crit = crit\n",
    "    \n",
    "    def forward(self, output, target:bool, **kwargs):\n",
    "        # Ones when `target` is truthy, zeros otherwise; built from `output` so it shares its dtype and device.\n",
    "        make_target = output.new_ones if target else output.new_zeros\n",
    "        return self.crit(output, make_target(*output.size()), **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main loss used to train the generators. It has three parts:\n",
    "- the classic GAN loss: they must make the critics believe their images are real\n",
    "- identity loss: if they are given an image from the set they are trying to imitate, they should return the same thing\n",
    "- cycle loss: if an image from A goes through the generator that imitates B then through the generator that imitates A, it should be the same as the initial image. Same for B and switching the generators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGanLoss(nn.Module):\n",
    "    \"Generator loss: identity + adversarial + cycle terms, each also stored on `self` after `forward`.\"\n",
    "    \n",
    "    def __init__(self, cgan:nn.Module, lambda_A:float=10., lambda_B:float=10, lambda_idt:float=0.5, lsgan:bool=True):\n",
    "        super().__init__()\n",
    "        self.cgan,self.l_A,self.l_B,self.l_idt = cgan,lambda_A,lambda_B,lambda_idt\n",
    "        self.crit = AdaptiveLoss(F.mse_loss if lsgan else F.binary_cross_entropy)\n",
    "    \n",
    "    def set_input(self, input):\n",
    "        \"Store the real batch `(real_A, real_B)`; must be called before `forward`.\"\n",
    "        self.real_A,self.real_B = input\n",
    "    \n",
    "    def forward(self, output, target):\n",
    "        \"Return `id_loss + gen_loss + cyc_loss` for `output = [fake_A, fake_B, idt_A, idt_B]`; `target` is unused.\"\n",
    "        fake_A, fake_B, idt_A, idt_B = output\n",
    "        #Generators should return identity on the datasets they try to convert to:\n",
    "        #idt_A is G_A applied to real_A and idt_B is G_B applied to real_B, so each is compared with its own input.\n",
    "        self.id_loss = self.l_idt * (self.l_A * F.l1_loss(idt_A, self.real_A) + self.l_B * F.l1_loss(idt_B, self.real_B))\n",
    "        #Generators are trained to trick the discriminators so the following should be ones\n",
    "        self.gen_loss = self.crit(self.cgan.D_A(fake_A), True) + self.crit(self.cgan.D_B(fake_B), True)\n",
    "        #Cycle loss\n",
    "        self.cyc_loss  = self.l_A * F.l1_loss(self.cgan.G_A(fake_B), self.real_A)\n",
    "        self.cyc_loss += self.l_B * F.l1_loss(self.cgan.G_B(fake_A), self.real_B)\n",
    "        return self.id_loss+self.gen_loss+self.cyc_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main callback to train a cycle GAN. The training loop will train the generators (so `learn.opt` is given those parameters) while the critics are trained by the callback during `on_batch_end`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGANTrainer(LearnerCallback):\n",
    "    \"Callback that steps the two critics in `on_batch_end` and tracks the individual loss terms.\"\n",
    "    _order = -20 #Need to run before the Recorder\n",
    "    \n",
    "    def _set_trainable(self, D_A=False, D_B=False):\n",
    "        \"Toggle `requires_grad`: the generators are trainable only when neither `D_A` nor `D_B` is requested.\"\n",
    "        gen = (not D_A) and (not D_B)\n",
    "        requires_grad(self.learn.model.G_A, gen)\n",
    "        requires_grad(self.learn.model.G_B, gen)\n",
    "        requires_grad(self.learn.model.D_A, D_A)\n",
    "        requires_grad(self.learn.model.D_B, D_B)\n",
    "        if not gen:\n",
    "            # Copy the current hyper-parameters of `learn.opt` onto both critic optimizers.\n",
    "            self.opt_D_A.lr, self.opt_D_A.mom = self.learn.opt.lr, self.learn.opt.mom\n",
    "            self.opt_D_A.wd, self.opt_D_A.beta = self.learn.opt.wd, self.learn.opt.beta\n",
    "            self.opt_D_B.lr, self.opt_D_B.mom = self.learn.opt.lr, self.learn.opt.mom\n",
    "            self.opt_D_B.wd, self.opt_D_B.beta = self.learn.opt.wd, self.learn.opt.beta\n",
    "    \n",
    "    def on_train_begin(self, **kwargs):\n",
    "        \"Cache the sub-models, create (or reuse) the optimizers and register the extra metric names.\"\n",
    "        self.G_A,self.G_B = self.learn.model.G_A,self.learn.model.G_B\n",
    "        self.D_A,self.D_B = self.learn.model.D_A,self.learn.model.D_B\n",
    "        self.crit = self.learn.loss_func.crit\n",
    "        # First call: build one optimizer over the flattened layers of both generators.\n",
    "        if not getattr(self,'opt_G',None):\n",
    "            self.opt_G = self.learn.opt.new([nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))])\n",
    "        else: \n",
    "            # Later calls: reuse `opt_G` and refresh its hyper-parameters.\n",
    "            # NOTE(review): `self.opt` is never assigned in this class - presumably it resolves to\n",
    "            # `self.learn.opt` through `LearnerCallback` attribute delegation; confirm.\n",
    "            self.opt_G.lr,self.opt_G.wd = self.opt.lr,self.opt.wd\n",
    "            self.opt_G.mom,self.opt_G.beta = self.opt.mom,self.opt.beta\n",
    "        if not getattr(self,'opt_D_A',None):\n",
    "            self.opt_D_A = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_A))])\n",
    "        if not getattr(self,'opt_D_B',None):\n",
    "            self.opt_D_B = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_B))])\n",
    "        # The learner's optimizer wrapper now drives the generators' underlying optimizer.\n",
    "        self.learn.opt.opt = self.opt_G.opt\n",
    "        self._set_trainable()\n",
    "        self.id_smter,self.gen_smter,self.cyc_smter = SmoothenValue(0.98),SmoothenValue(0.98),SmoothenValue(0.98)\n",
    "        self.da_smter,self.db_smter = SmoothenValue(0.98),SmoothenValue(0.98)\n",
    "        self.recorder.add_metric_names(['id_loss', 'gen_loss', 'cyc_loss', 'D_A_loss', 'D_B_loss'])\n",
    "        \n",
    "    def on_batch_begin(self, last_input, **kwargs):\n",
    "        \"Hand the current input batch to the loss function.\"\n",
    "        self.learn.loss_func.set_input(last_input)\n",
    "    \n",
    "    def on_backward_begin(self, **kwargs):\n",
    "        \"Feed the three loss terms stored on the loss function into their smoothers.\"\n",
    "        self.id_smter.add_value(self.loss_func.id_loss.detach().cpu())\n",
    "        self.gen_smter.add_value(self.loss_func.gen_loss.detach().cpu())\n",
    "        self.cyc_smter.add_value(self.loss_func.cyc_loss.detach().cpu())\n",
    "    \n",
    "    def on_batch_end(self, last_input, last_output, **kwargs):\n",
    "        \"Run one optimizer step for each critic on the real images and the detached fakes.\"\n",
    "        self.G_A.zero_grad(); self.G_B.zero_grad()\n",
    "        # Detach the fakes so the critic losses do not backpropagate into the generators.\n",
    "        fake_A, fake_B = last_output[0].detach(), last_output[1].detach()\n",
    "        real_A, real_B = last_input\n",
    "        self._set_trainable(D_A=True)\n",
    "        self.D_A.zero_grad()\n",
    "        # Critic loss: real images are scored against ones, fakes against zeros.\n",
    "        loss_D_A = 0.5 * (self.crit(self.D_A(real_A), True) + self.crit(self.D_A(fake_A), False))\n",
    "        self.da_smter.add_value(loss_D_A.detach().cpu())\n",
    "        loss_D_A.backward()\n",
    "        self.opt_D_A.step()\n",
    "        self._set_trainable(D_B=True)\n",
    "        self.D_B.zero_grad()\n",
    "        loss_D_B = 0.5 * (self.crit(self.D_B(real_B), True) + self.crit(self.D_B(fake_B), False))\n",
    "        self.db_smter.add_value(loss_D_B.detach().cpu())\n",
    "        loss_D_B.backward()\n",
    "        self.opt_D_B.step()\n",
    "        # Back to generator training for the next batch.\n",
    "        self._set_trainable()\n",
    "        \n",
    "    def on_epoch_end(self, last_metrics, **kwargs):\n",
    "        \"Append the smoothed values of the five tracked losses to `last_metrics`.\"\n",
    "        return add_metrics(last_metrics, [s.smooth for s in [self.id_smter,self.gen_smter,self.cyc_smter,\n",
    "                                                             self.da_smter,self.db_smter]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `gen_blocks=9` overrides the default of 6; the loss gets the same `cycle_gan` instance so it can call its sub-models,\n",
    "# and `CycleGANTrainer` is attached through `callback_fns`.\n",
    "cycle_gan = CycleGAN(3,3, gen_blocks=9)\n",
    "learn = Learner(data, cycle_gan, loss_func=CycleGanLoss(cycle_gan), opt_func=partial(optim.Adam, betas=(0.5,0.99)),\n",
    "               callback_fns=[CycleGANTrainer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(100, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('100fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = learn.load('100fit')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at some results using `Learner.show_results`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.show_results(ds_type=DatasetType.Train, rows=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.show_results(ds_type=DatasetType.Train, rows=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's go through all the images of the training set and find the ones that are converted best or worst, according to our critics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(learn.data.train_ds.items),len(learn.data.train_ds.itemsB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_batch(filenames, tfms, **kwargs):\n",
    "    \"Open `filenames`, apply `tfms` (with `kwargs`) to each image and return one CUDA batch mapped through `2*(x-0.5)`.\"\n",
    "    # Keep the value returned by `apply_tfms`: rebinding a loop variable (`for s in samples: s = ...`)\n",
    "    # leaves the list unchanged, so the transforms and any `size` in `kwargs` would be ignored.\n",
    "    samples = [open_image(fn).apply_tfms(tfms, **kwargs) for fn in filenames]\n",
    "    batch = torch.stack([s.data for s in samples], 0).cuda()\n",
    "    return 2. * (batch - 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fnames = learn.data.train_ds.items[:8]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = get_batch(fnames, get_transforms()[1], size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.model.eval()\n",
    "# `tfms` is read as a global by `get_losses` and `show_best` below.\n",
    "tfms = get_transforms()[1]\n",
    "bs = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_losses(fnames, gen, crit, bs=16):\n",
    "    \"\"\"Return `(losses_in, losses_out)`: one loss per file for `crit` on the inputs and on `gen`'s outputs.\n",
    "    Reads the notebook globals `tfms` and `learn`.\"\"\"\n",
    "    def per_item(loss):\n",
    "        # Average everything but the batch dimension.\n",
    "        return loss.view(loss.size(0),-1).mean(1)\n",
    "\n",
    "    in_chunks,out_chunks = [],[]\n",
    "    with torch.no_grad():\n",
    "        for start in progress_bar(range(0, len(fnames), bs)):\n",
    "            xb = get_batch(fnames[start:start+bs], tfms, size=128)\n",
    "            fakes = gen(xb)\n",
    "            preds_in = crit(xb)\n",
    "            preds_out = crit(fakes)\n",
    "            # Both predictions are scored against the `True` target, element-wise.\n",
    "            in_chunks.append(per_item(learn.loss_func.crit(preds_in, True,reduction='none')))\n",
    "            out_chunks.append(per_item(learn.loss_func.crit(preds_out,True,reduction='none')))\n",
    "    return torch.cat(in_chunks),torch.cat(out_chunks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Images from A go through `G_B` and are scored by `D_B`; `losses_A` is the pair (loss on inputs, loss on generated).\n",
    "losses_A = get_losses(data.train_ds.x.items, learn.model.G_B, learn.model.D_B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Images from B go through `G_A` and are scored by `D_A`; `losses_B` is the pair (loss on inputs, loss on generated).\n",
    "losses_B = get_losses(data.train_ds.x.itemsB, learn.model.G_A, learn.model.D_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_best(fnames, losses, gen, n=8):\n",
    "    \"\"\"Show the `n` items of `fnames` with the lowest `losses`, each next to `gen`'s output for it.\n",
    "    The loss is used as the title of the generated image. Reads the notebook global `tfms`.\"\"\"\n",
    "    sort_idx = losses.argsort().cpu()\n",
    "    # Never ask for more items than there are losses.\n",
    "    n = min(n, len(sort_idx))\n",
    "    if n <= 0: return\n",
    "    # Each item uses two axes and a row has four: round the row count up so an odd `n` (or `n == 1`) still fits.\n",
    "    n_rows = (n+1)//2\n",
    "    _,axs = plt.subplots(n_rows, 4, figsize=(12,2*n), squeeze=False)\n",
    "    axs = axs.flatten()\n",
    "    xb = get_batch(fnames[sort_idx][:n], tfms, size=128)\n",
    "    with torch.no_grad():\n",
    "        fakes = gen(xb)\n",
    "    # Map both batches through `(1+x)/2` before display.\n",
    "    xb,fakes = (1+xb.cpu())/2,(1+fakes.cpu())/2\n",
    "    sorted_losses = losses[sort_idx]\n",
    "    for i in range(n):\n",
    "        ax_in,ax_out = axs[2*i],axs[2*i+1]\n",
    "        ax_in.imshow(xb[i].permute(1,2,0))\n",
    "        ax_in.axis('off')\n",
    "        ax_out.imshow(fakes[i].permute(1,2,0))\n",
    "        ax_out.set_title(sorted_losses[i].item())\n",
    "        ax_out.axis('off')\n",
    "    # Hide the axes left over when `n` is odd.\n",
    "    for ax in axs[2*n:]: ax.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_best(data.train_ds.x.items, losses_A[1], learn.model.G_B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_best(data.train_ds.x.itemsB, losses_B[1], learn.model.G_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank by the difference between the critic loss on the generated image and on the real input (smallest first).\n",
    "show_best(data.train_ds.x.items, losses_A[1]-losses_A[0], learn.model.G_B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
